## master preamble ------------------------------------------------------
## Start from an empty workspace so objects left over from an earlier
## session cannot leak into this run.
rm(list=ls(all=TRUE))

## inherit directories from stata -----------
## Trailing command-line arguments:
##   1: project root (a trailing '/' is appended here)
##   2: package library to load from; '' or absent means "use the default
##      library paths"
args <- commandArgs(trailingOnly = TRUE)

## Defaults, so that both names always exist: without them a run with no
## arguments stops at the first use of dir.lib with "object not found", and
## a run with a single argument leaves dir.lib as NA.
dir.proj <- paste0(getwd(), '/')
dir.lib  <- ''
if (length(args) >= 1) {
  dir.proj <- paste0(args[1], '/')
}
if (length(args) >= 2) {
  dir.lib <- args[2]
}

## load packages ------------------------------
## The attach order is kept exactly as listed, because packages attached
## later mask same-named functions of packages attached earlier.
if (dir.lib != '') {
  ## loading from the given library: dependencies first, main packages last
  pkgs.dep  <- c('R.utils', 'ggplot2', 'tidyr', 'stringr', 'forcats',
                 'lubridate', 'haven', 'data.table', 'backports')
  pkgs.main <- c('rio', 'dplyr', 'tidyverse', 'fixest', 'broom')
  for (pkg in c(pkgs.dep, pkgs.main)) {
    library(pkg, character.only = TRUE, lib.loc = dir.lib)
  }
} else {
  ## otherwise: simple load of the main packages from the default paths
  for (pkg in c('rio', 'dplyr', 'tidyverse', 'fixest', 'broom')) {
    library(pkg, character.only = TRUE)
  }
}
## --------------------------------------------------------------------

## file preamble ------------------------------------------------------
## randomization seed: fixes the random bootstrap weights drawn below ----
set.seed(413)

## directories (dir.proj already carries its trailing '/') --------
dir.data <- paste0(dir.proj, 'data/out/')
dir.out  <- paste0(dir.proj, 'estimates/')

## file-specific -------
name.Y   <- 'cite_ny1'    # column of the input data that becomes Y
name.out <- 'history_q2'  # stem of the exported csv file

## parameters --------
Q       <- 2        # polynomial order in Z used in the loop
type    <- 'joint'  # fixed-effect specification: 'joint' or 'sep'
bootrep <- 250      # number of bootstrap replications

## functions ----------------
## rudirichlet(n, d): n-by-d matrix of positive weights, each row summing to 1.
##   n  number of rows (one weight vector per row)
##   d  number of columns (length of each weight vector)
## Every entry is an independent rate-1 exponential draw divided by the total
## of its row.
rudirichlet <- function(n, d) {
  draws <- matrix(rexp(d * n, rate = 1), nrow = n, ncol = d)
  ## rowSums() has length n, so it is recycled down each column and every
  ## entry ends up divided by its own row total
  draws / rowSums(draws)
}
## ---------------------------------------------------------------------


## Setup data -------------------
## Read the main file, add short working names
## (Y = column named by name.Y, D = harsh, W = covbin1, jid = officerid)
## and keep the listed columns plus everything starting with 'lenient'.
data <- import(paste0(dir.data, '4-main.dta'))
data <- data %>%
  mutate(
    Y   = data[[name.Y]],
    D   = harsh,
    W   = covbin1,
    jid = officerid) %>%
  select(
    citationid, Y, D, Z, W, totfe, jid, countynum, troop,
    starts_with('lenient'), cite_py2)

## G = heterogeneity group ------
## cite_py2 is recoded to the integer codes of its factor levels (1, 2, ...);
## Gfe holds the same grouping as a factor and maxG is the largest code.
data <- data %>%
  mutate(
    G   = as.numeric(as.factor(cite_py2)),
    Gfe = as.factor(G))
maxG <- max(data[['G']])


## Polynomials in Z ---
## Column Zk holds Z^k. Orders 1-8 are always created, as before; when the
## polynomial order Q set in the parameters is larger than 8 the extra powers
## are added as well, so the Z1..ZQ regressors used in the loop always exist.
for (q in seq_len(max(8, Q))) {
  data[[paste0('Z', q)]] <- data[['Z']]^q
}

## Specification for FE ------
## fe.use lists the variables pasted into the fixed-effects part of the
## model formulas in the loop:
##   'sep'   -> totfe_G
##   'joint' -> totfe + Gfe
## Any other value stops here with a clear message instead of surfacing later
## as "object 'fe.use' not found" inside the loop.
## NOTE(review): totfe_G is not among the columns kept in the data setup, so
## type = 'sep' presumably needs that column added first -- confirm.
if (identical(type, 'sep')) {
  fe.use <- 'totfe_G'
} else if (identical(type, 'joint')) {
  fe.use <- c('totfe', 'Gfe')
} else {
  stop("Unknown type: expected 'sep' or 'joint'", call. = FALSE)
}

## Setup for clustered bootstrap -------------------
## `list` has one row per distinct cluster id (jid); `weights` has one row
## per cluster and one column per bootstrap replication.
## rudirichlet() normalises within rows, so each cluster's weights sum to 10
## across the bootrep columns; a single column (the weights used in one
## replication) is not itself normalised across clusters.
## NOTE(review): confirm this orientation is intended -- a Dirichlet draw
## over clusters for each replication would instead be
## t(rudirichlet(bootrep, nrow(list))).
list    <- data %>% select(jid) %>% unique()
weights <- 10 * rudirichlet(nrow(list), bootrep)


## Looping -----------
## Iteration 0 uses unit weights and gives the point estimates; iterations
## 1..bootrep reweight whole clusters (jid) with column j of `weights`.
## Every iteration produces a long table `est` (G, par, est); the tables are
## stacked into boot.out with an `iter` column once the loop has finished.

## Model formulas: they do not change across iterations, so build them once.
fe.rhs   <- paste(fe.use, collapse = '+')
z.terms  <- paste(paste0('Z', seq_len(Q)), collapse = '+')
fml.len  <- as.formula(glue::glue('Y~i(G,lenientpart)|', fe.rhs))
fml.late <- as.formula(glue::glue('Y~0|', fe.rhs, '|i(G,D)~i(G,Z)'))
fml.poly <- as.formula(glue::glue(
  'Y~i(G)*(', z.terms, ')', '-i(G)-(', z.terms, ')|', fe.rhs))

## Mean of the estimated fixed effects within each group G.
##   fit        a fitted feols model
##   boot.data  the data of the current iteration
##   type       'joint' (sum of totfe and Gfe effects) or 'sep' (totfe_G)
## Returns one row per G with the column `fe`.
## NOTE(review): this within-group mean is not weighted by bootwt, unlike the
## sample means of Y and D in the loop -- confirm that is intended.
group.fe.means <- function(fit, boot.data, type) {
  fe <- bind_cols(
    boot.data %>% select(G),
    stats::predict(fit, newdata = boot.data, fixef = TRUE))
  if (type == 'joint') {
    fe <- mutate(fe, sumfe = totfe + Gfe)
  }
  if (type == 'sep') {
    fe <- mutate(fe, sumfe = totfe_G)
  }
  fe %>% group_by(G) %>% summarize(fe = mean(sumfe, na.rm = TRUE))
}

## One result table per iteration; binding once after the loop avoids
## re-copying the accumulated output on every pass.
boot.res <- vector('list', bootrep + 1)

for (j in 0:bootrep) {

  ## Setup bootstrap data --------
  ## every observation inherits the weight of its cluster
  boot.list <- list %>% mutate(bootwt = if (j == 0) 1 else weights[, j])
  boot.data <- inner_join(data, boot.list, by = 'jid')

  ## Get group weights -----------
  ## wt = share of total bootstrap weight falling in each group
  mu <- boot.data %>% group_by(G) %>% summarize(N = sum(bootwt)) %>%
    mutate(wt = N / sum(N)) %>% select(c(G, wt))

  ## Get sample means ------------
  mu <- inner_join(
    mu,
    boot.data %>% group_by(G) %>% summarize(
      Y = weighted.mean(Y, bootwt),
      p = weighted.mean(D, bootwt)), by = 'G')

  ## Adjusted harshness --------
  ## Group 1 is the omitted category (coefficient 0). From here on new columns
  ## are attached to mu by position: mu is sorted by G, and the coefficient
  ## vectors are assumed to follow the same order -- TODO confirm.
  fit <- feols(D~i(G,keep=seq(2,maxG))|totfe, boot.data, weights=~bootwt)
  mu <- mutate(
    mu,
    padj = as.numeric(c(0, fit$coefficients)) +
      mean(predict(fit, newdata = boot.data, fixef = TRUE)$totfe))

  ## Lenient Means -----------
  fit <- fixest::feols(fml.len, boot.data, weights = ~bootwt)
  fe  <- group.fe.means(fit, boot.data, type)
  mu  <- mutate(
    mu, y0len = fe$fe + fit$coefficients)

  ## Group-specific LATE -----------
  ## D interacted with G, instrumented by Z interacted with G
  fit <- fixest::feols(fml.late, boot.data, weights = ~bootwt)
  mu  <- mutate(
    mu, late = fit$coefficients)

  ## Get other pars ------------
  ## Terms are split on '::' and ':'; rows without a D part become y0d0,
  ## rows with one become delta, and y1d1 = y0d0 + delta.
  fit  <- feols(Y~0+D*i(G)-D, boot.data, weights=~bootwt)
  coef <- broom::tidy(fit) %>%
    separate(term, sep = '::', into = c('stub1', 'stub2')) %>%
    separate(stub2, sep = ':', into = c('G', 'D'), convert = TRUE, fill = 'right')
  mu <- inner_join(
    mu,
    inner_join(
      coef %>% filter(is.na(D)) %>% select(G, y0d0 = estimate),
      coef %>% filter(!is.na(D)) %>% select(G, delta = estimate), by = 'G') %>%
      mutate(y1d1 = y0d0 + delta) %>%
      select(G, y0d0, y1d1), by = 'G')

  ## Single polynomial order (set above) -----
  ## only the G-by-Z^q interactions are kept as regressors
  fit <- fixest::feols(fml.poly, boot.data, weights = ~bootwt)

  ## Means of FE --------
  fe <- group.fe.means(fit, boot.data, type)

  ## Coefficients -------
  ## strip the 'G::' prefix, split what is left at ':Z' into group and
  ## order, then sum the estimates over the orders within each group
  coef <- broom::tidy(fit) %>%
    mutate(temp = stringr::str_remove(term, 'G::')) %>%
    separate(temp, sep = ':Z', into = c('G', 'q')) %>%
    mutate(G = as.numeric(G)) %>%
    group_by(G) %>% summarize(sum = sum(estimate))

  ## Compile stuff (with reformat) --------
  ## derived quantities, then one row per (G, parameter)
  est <- mu %>%
    left_join(fe, by = 'G') %>% left_join(coef, by = 'G') %>%
    mutate(
      y0=fe,y1=fe+sum,ate=(y1-y0),att=(Y-y0)/p,atu=(y1-Y)/(1-p),
      y1d0=(1/(1-p))*y1-(p/(1-p))*y1d1,y0d1=(1/p)*y0-((1-p)/p)*y0d0,
      y0d1my0d0=y0d1-y0d0,attmatu=att-atu) %>% select(-c(fe, sum)) %>%
    pivot_longer(!G, names_to = 'par', values_to = 'est') %>% as.data.frame()

  ## Post output ------------
  boot.res[[j + 1]] <- est %>% mutate(iter = j)

  ## Progress report -----
  print(paste('Iteration',j,'of',bootrep,'Completed'))

}
boot.out <- bind_rows(boot.res)
## End of bootstrap -------------------------

## Bootstrap output: parameters -------
## Point estimates come from iteration 0. For each (G, par) the replications
## (iter >= 1) give the standard deviation and the 5% / 95% quantiles.
## NOTE(review): lq/uq span a 90% range while lb/ub use a 1.96 multiplier --
## confirm the two are meant to differ.
est.main <- boot.out %>% filter(iter == 0) %>% select(-c(iter))
se.main  <- boot.out %>% filter(iter >= 1) %>% group_by(G, par) %>%
  summarize(
    se = sd(est),
    lq = quantile(est, 0.05),
    uq = quantile(est, 0.95)) %>%
  ungroup()
## join keys spelled out: the two tables share exactly G and par
est.out <- inner_join(est.main, se.main, by = c('G', 'par')) %>%
  mutate(lb = est - 1.96 * se, ub = est + 1.96 * se)
export(est.out, paste0(dir.out, name.out, '.csv'))


